% Parameters
num_firms = 110; % Number of firms
T = 120; % Number of time periods (months)
num_states = 2; % Two regimes: Normal and High-Volatility
num_iters = 100; % Maximum number of EM iterations
tol = 1e-6; % Convergence tolerance

% Seed the generator BEFORE any random draw so the simulated panel (and hence
% every downstream estimate) is reproducible from run to run.
rng(1);

% Example simulation (replace with actual data loading)
prices = randn(T, num_firms) * 50 + 100; % Simulated prices, T x num_firms

% Jump indicator with the same T x num_firms shape as prices. A NaN ROW is
% prepended (vertical concatenation) so that differencing along time keeps
% one row per period; the first period has no predecessor and is never a jump.
jumps = abs(diff([NaN(1, num_firms); prices], 1, 1)) > 10;

prices(rand(T, num_firms) < 0.1) = NaN; % Introduce random missing data


% Extract valid price changes for a single firm as an example
firm_prices = prices(:, 1); % Example firm data (column vector)
valid_idx = ~isnan(firm_prices);

if sum(valid_idx) > num_states + 10 % Ensure sufficient data for estimation
    % Price changes between consecutive observed prices (column vector).
    % T is deliberately not reassigned here: it is the number of calendar
    % periods and is reused later in the file to size panel-level arrays.
    price_changes = diff(firm_prices(valid_idx));

    % Initialize HMM parameters
    initial_probs = [0.5, 0.5]; % Start with equal probabilities
    transition_matrix = [0.9, 0.1; 0.1, 0.9]; % Initial guess
    means = [mean(price_changes), mean(price_changes) + std(price_changes)]; % State-dependent means
    variances = [var(price_changes), var(price_changes) * 2]; % State-dependent variances

    % EM Algorithm
    for iter = 1:num_iters
        % E-Step: scaled forward and backward probabilities
        [alpha, c] = forward(price_changes, initial_probs, transition_matrix, means, variances);
        beta = backward(price_changes, transition_matrix, means, variances, c);

        % Posterior state probabilities; row-normalising removes the
        % per-period scaling constants carried by alpha and beta.
        gamma = alpha .* beta;
        gamma = gamma ./ sum(gamma, 2);

        % Pairwise (t -> t+1) state probabilities
        xi = compute_xi(price_changes, alpha, beta, transition_matrix, means, variances);

        % M-Step: Update parameters
        initial_probs = gamma(1, :); % Update initial state probabilities
        % Row i is divided by the expected number of departures from state i.
        transition_matrix = sum(xi, 3) ./ sum(gamma(1:end-1, :), 1)';

        for s = 1:num_states
            % gamma(:, s) and price_changes are both column vectors, so the
            % products below are element-wise weighted sums (no transpose:
            % a transpose would trigger implicit expansion to a matrix).
            state_weight = sum(gamma(:, s));
            means(s) = sum(gamma(:, s) .* price_changes) / state_weight; % Update means
            variances(s) = sum(gamma(:, s) .* (price_changes - means(s)).^2) / state_weight; % Update variances
        end

        % Check for convergence (change in the transition matrix)
        if iter > 1 && max(abs(transition_matrix(:) - prev_trans(:))) < tol
            fprintf('EM converged at iteration %d\n', iter);
            break;
        end
        prev_trans = transition_matrix;
    end

    % Decode states using Viterbi algorithm
    states = viterbi(price_changes, initial_probs, transition_matrix, means, variances);

    % Display results
    disp('Estimated Transition Matrix:');
    disp(transition_matrix);
    disp('Estimated Means:');
    disp(means);
    disp('Estimated Variances:');
    disp(variances);
else
    fprintf('Insufficient data for GHMM estimation.\n');
end

% Helper Functions
function [alpha, c] = forward(obs, init_probs, trans, means, vars)
    % FORWARD  Scaled forward pass for a Gaussian-emission HMM.
    %   obs        - observation sequence (indexed linearly)
    %   init_probs - initial state probabilities
    %   trans      - transition matrix, trans(i, j) = P(state j | state i)
    %   means/vars - per-state emission mean and variance
    % Returns alpha (one row per observation, each row rescaled to sum to 1)
    % and c, the column of per-step normalising constants used for rescaling.
    n_obs = length(obs);
    n_states = length(init_probs);
    alpha = zeros(n_obs, n_states);
    c = zeros(n_obs, 1);
    sd = sqrt(vars); % emission standard deviations

    % First observation: prior times emission density
    for k = 1:n_states
        alpha(1, k) = init_probs(k) * normpdf(obs(1), means(k), sd(k));
    end
    c(1) = sum(alpha(1, :));
    alpha(1, :) = alpha(1, :) ./ c(1);

    % Remaining observations: propagate through trans, weight by emission
    for t = 2:n_obs
        % reach(k) = sum_i alpha(t-1, i) * trans(i, k)
        reach = sum(alpha(t-1, :)' .* trans, 1);
        for k = 1:n_states
            alpha(t, k) = normpdf(obs(t), means(k), sd(k)) * reach(k);
        end
        c(t) = sum(alpha(t, :));
        alpha(t, :) = alpha(t, :) ./ c(t);
    end
end

function beta = backward(obs, trans, means, vars, c)
    % BACKWARD  Scaled backward pass for a Gaussian-emission HMM.
    %   obs        - observation sequence
    %   trans      - transition matrix, trans(i, j) = P(state j | state i)
    %   means/vars - per-state emission mean and variance (row vectors, so
    %                they line up with a row of trans)
    %   c          - per-step scaling constants returned by the forward pass
    % Returns beta with one row per observation; row t is divided by c(t).
    n_obs = length(obs);
    n_states = size(trans, 1);
    beta = zeros(n_obs, n_states);
    sd = sqrt(vars);

    % Last observation
    beta(n_obs, :) = 1 / c(n_obs);

    % Walk backwards through the sequence
    for t = (n_obs - 1):-1:1
        % Emission density of the next observation under every state;
        % identical for all source states, so evaluate it once per step.
        next_density = normpdf(obs(t+1), means, sd);
        for k = 1:n_states
            contrib = trans(k, :) .* next_density .* beta(t+1, :);
            beta(t, k) = sum(contrib) / c(t);
        end
    end
end

function xi = compute_xi(obs, alpha, beta, trans, means, vars)
    % COMPUTE_XI  Pairwise state posteriors for consecutive observations.
    %   xi(i, j, t) is proportional to
    %       alpha(t, i) * trans(i, j) * N(obs(t+1); means(j), vars(j)) * beta(t+1, j)
    %   and each slice xi(:, :, t) is normalised to sum to 1.
    %   means and vars are expected as row vectors (one entry per state).
    n_obs = length(obs);
    n_states = size(trans, 1);
    xi = zeros(n_states, n_states, n_obs - 1);
    sd = sqrt(vars);

    for t = 1:(n_obs - 1)
        % Unnormalised joint weight of every (i, j) pair: the column
        % alpha(t, :)' varies over source states i, the row terms over
        % destination states j.
        joint = alpha(t, :)' .* trans .* ...
                normpdf(obs(t+1), means, sd) .* beta(t+1, :);
        xi(:, :, t) = joint / sum(sum(joint));
    end
end

function states = viterbi(obs, init_probs, trans, means, vars)
    % VITERBI  Most likely state path of a Gaussian-emission HMM.
    %   obs        - observation sequence
    %   init_probs - initial state probabilities
    %   trans      - transition matrix, trans(i, j) = P(state j | state i)
    %   means/vars - per-state emission mean and variance
    % Returns states, a 1 x length(obs) row vector of state indices
    % (empty 1 x 0 for an empty sequence).
    %
    % Emission log-densities are evaluated analytically instead of as
    % log(normpdf(...)): the density underflows to 0 for observations far in
    % the tail, which turns every state's score into -Inf and makes the
    % decoded path fall back to state 1 regardless of the data.
    T = length(obs);
    num_states = length(init_probs);
    states = zeros(1, T);
    if T == 0
        return;
    end

    mu = reshape(means, 1, []);      % force row orientation
    sigma2 = reshape(vars, 1, []);
    log_norm = -0.5 * log(2 * pi * sigma2);
    log_trans = log(trans);          % log(0) = -Inf marks forbidden moves

    delta = zeros(T, num_states);    % best log-score of a path ending in each state
    psi = zeros(T, num_states);      % predecessor state achieving that score

    % Initialization
    delta(1, :) = log(reshape(init_probs, 1, [])) + ...
                  log_norm - (obs(1) - mu).^2 ./ (2 * sigma2);

    % Recursion
    for t = 2:T
        % Column i of delta(t-1, :)' is added to row i of log_trans; the
        % column-wise max picks the best predecessor for each destination
        % state (ties resolve to the lowest index).
        [best_prev, psi(t, :)] = max(delta(t-1, :)' + log_trans, [], 1);
        delta(t, :) = best_prev + log_norm - (obs(t) - mu).^2 ./ (2 * sigma2);
    end

    % Termination
    [~, states(T)] = max(delta(T, :));

    % Backtrack
    for t = T-1:-1:1
        states(t) = psi(t+1, states(t+1));
    end
end

%----- Panel of Firms ----------------------------
% NOTE(review): this section is script code that follows local function
% definitions. MATLAB releases before R2024a require local functions to sit
% at the very end of a script file; on those releases move the helper
% functions below this section.
%
% Each firm's price changes are continuous, so the model is fitted with the
% Gaussian-emission EM built from forward/backward/compute_xi and decoded
% with viterbi. (hmmtrain/hmmdecode model discrete symbol sequences and
% hmmdecode returns posterior probabilities rather than a state path.)

% Number of calendar periods, taken from the data itself rather than from T,
% which earlier code in this file reassigns.
num_periods = size(prices, 1);

% Initialize storage for results
estimated_states = NaN(num_periods, num_firms); % Inferred states
transition_matrices = NaN(num_states, num_states, num_firms); % Transition matrices

for firm = 1:num_firms
    % Extract firm's data and handle missing observations
    firm_prices = prices(:, firm);
    valid_idx = ~isnan(firm_prices);
    valid_times = find(valid_idx); % calendar index of each observed price

    if numel(valid_times) > num_states + 10 % Ensure sufficient data
        % Changes between consecutive OBSERVED prices; change k ends at
        % calendar period valid_times(k + 1).
        price_changes = diff(firm_prices(valid_times));

        % Best effort per firm: a failure is reported and the firm is left as NaN.
        try
            [est_trans, est_means, est_vars, est_init] = ...
                fit_gaussian_hmm(price_changes, num_states, num_iters, tol);

            % Decode states
            states = viterbi(price_changes, est_init, est_trans, est_means, est_vars);

            % Store results
            transition_matrices(:, :, firm) = est_trans;

            % Align states to the original timeline: each state is placed at
            % the period in which its price change is realised.
            estimated_states(valid_times(2:end), firm) = states(:);
        catch ME
            fprintf('HMM estimation failed for firm %d: %s\n', firm, ME.message);
        end
    else
        fprintf('Insufficient data for firm %d\n', firm);
    end
end

% Save results
save('HMM_Estimation_Results.mat', 'estimated_states', 'transition_matrices');
disp('HMM estimation complete.');


% -------------- Post-Jump Analysis
% The jump indicator is indexed as jumps(period, firm); flag a shape that
% cannot line up with the price panel instead of silently mis-indexing.
if ~isequal(size(jumps), size(prices))
    warning('jumps is %d x %d but prices is %d x %d; jump timing may be misaligned.', ...
        size(jumps, 1), size(jumps, 2), size(prices, 1), size(prices, 2));
end

horizon = 5; % number of periods tracked after each jump

% Initialize storage for post-jump analysis (one row per usable jump)
post_jump_states = NaN(nnz(jumps), horizon); % States in the 5 periods post-jump
num_stored = 0; % running row counter shared across firms

for firm = 1:num_firms
    % Find jumps for the firm
    firm_jumps = find(jumps(:, firm));

    for i = 1:length(firm_jumps)
        idx = firm_jumps(i);
        window = (idx + 1):(idx + horizon); % the periods strictly AFTER the jump
        if window(end) <= num_periods && ~any(isnan(estimated_states(window, firm)))
            % Store 5 periods of states after the jump
            num_stored = num_stored + 1;
            post_jump_states(num_stored, :) = estimated_states(window, firm)';
        end
    end
end

% Drop rows that were never filled: NaN == k evaluates to false, so leaving
% them in would count them as "not high-volatility" and bias the proportion.
post_jump_states = post_jump_states(1:num_stored, :);

% Aggregate post-jump regime frequencies. States are ordered by variance at
% estimation time, so the highest index is the high-volatility regime.
regime_frequencies = mean(post_jump_states == num_states, 1); % Proportion of high-volatility regime

% Plot regime frequencies
figure;
bar(1:horizon, regime_frequencies);
xlabel('Periods After Jump');
ylabel('Proportion of High-Volatility Regime');
title('Proportion of High-Volatility Regime After Jumps');
grid on;

% ---------- Transition Matrices -------------------------------
% Calculate mean transition matrix across firms
valid_firms = ~isnan(transition_matrices(1, 1, :));
mean_transition_matrix = mean(transition_matrices(:, :, valid_firms), 3, 'omitnan');

% Display transition matrix
disp('Mean Transition Matrix Across Firms:');
disp(mean_transition_matrix);

% Persistence in high-volatility regime (probability of staying in it)
firm_persistence = squeeze(transition_matrices(num_states, num_states, valid_firms));
persistence_high_vol = mean(firm_persistence(:), 'omitnan');

disp('Persistence in High-Volatility Regime:');
disp(persistence_high_vol);

% Plot persistence across firms
figure;
histogram(firm_persistence, 20);
xlabel('Persistence in High-Volatility Regime');
ylabel('Number of Firms');
title('Distribution of High-Volatility Persistence Across Firms');
grid on;

function [trans, mu, sigma2, init_probs] = fit_gaussian_hmm(obs, num_states, max_iters, tol)
    % FIT_GAUSSIAN_HMM  EM (Baum-Welch) fit of a Gaussian-emission HMM.
    %   obs        - observation sequence (any orientation)
    %   num_states - number of hidden states (>= 2)
    %   max_iters  - maximum number of EM iterations
    %   tol        - stop when no transition probability moves by more than tol
    % Returns the transition matrix, per-state means and variances (row
    % vectors) and the initial state probabilities. States are relabelled in
    % order of increasing variance, so state num_states is the most volatile.
    % Errors if the data are degenerate or the estimates become non-finite.
    if num_states < 2
        error('fit_gaussian_hmm:numStates', 'num_states must be at least 2.');
    end
    obs = obs(:);
    spread = std(obs);
    if ~(spread > 0)
        error('fit_gaussian_hmm:degenerateData', ...
            'Price changes have zero or undefined dispersion.');
    end

    % Starting values; for two states these are equal priors,
    % [0.9 0.1; 0.1 0.9], means [m, m + s] and standard deviations [s, 2s].
    init_probs = ones(1, num_states) / num_states;
    trans = 0.9 * eye(num_states) + (0.1 / (num_states - 1)) * (1 - eye(num_states));
    mu = mean(obs) + (0:num_states-1) * spread;
    sigma2 = ((1:num_states) * spread) .^ 2;

    for iter = 1:max_iters
        prev_trans = trans;

        % E-step
        [alpha, c] = forward(obs, init_probs, trans, mu, sigma2);
        beta = backward(obs, trans, mu, sigma2, c);
        gamma = alpha .* beta;
        gamma = gamma ./ sum(gamma, 2);
        xi = compute_xi(obs, alpha, beta, trans, mu, sigma2);

        % M-step
        init_probs = gamma(1, :);
        trans = sum(xi, 3) ./ sum(gamma(1:end-1, :), 1)';
        occupancy = sum(gamma, 1);                      % expected time in each state
        mu = (gamma' * obs)' ./ occupancy;
        sigma2 = sum(gamma .* (obs - mu).^2, 1) ./ occupancy;

        if any(~isfinite([trans(:); mu(:); sigma2(:)]))
            error('fit_gaussian_hmm:nonFinite', ...
                'EM produced non-finite parameter estimates at iteration %d.', iter);
        end
        if max(abs(trans(:) - prev_trans(:))) < tol
            break;
        end
    end

    % EM labels are arbitrary; order states by variance so that downstream
    % code can identify the high-volatility regime by index.
    [sigma2, order] = sort(sigma2);
    mu = mu(order);
    init_probs = init_probs(order);
    trans = trans(order, order);
end